import torch
import torch.nn as nn


from train.mlbase import MLBase
from tool.logger import Logger
from model.model_util import disable_running_stats, enable_running_stats

from train.loss.util_loss import XentEC, NegEC


class Trainer(MLBase):
    """Compute the per-batch training loss using plain cross-entropy.

    ``__call__`` forwards a batch to the loss routine stored in
    ``self._learn``; in this class that routine is always :meth:`ce`.
    ``self.model`` and ``self.args`` are assumed to be provided by
    ``MLBase`` (TODO confirm).
    """

    def __init__(self, p):
        """Initialise the trainer from ``p``.

        Args:
            p: forwarded to ``MLBase.__init__`` as ``other`` — presumably an
                existing MLBase-like object whose state (``args``, ``model``,
                ...) is reused; verify against ``MLBase``.
        """
        super().__init__(other=p)
        # Loss routine invoked by __call__; kept as an attribute so it can be
        # swapped for a different routine without changing __call__.
        self._learn = self.ce
        self.criterion = nn.CrossEntropyLoss()
        # Move the criterion to the GPU only when one exists. Without class
        # weights the module holds no tensors, so this is effectively a no-op,
        # but the guard keeps CPU-only runs safe if weights are added later.
        if torch.cuda.is_available():
            self.criterion = self.criterion.cuda()
        self.log = Logger()

    def __call__(self, x, y, x_p=None):
        """Return the loss for one batch.

        Args:
            x: input batch passed to ``self.model``.
            y: targets in a form accepted by ``nn.CrossEntropyLoss``
                (class indices or class probabilities).
            x_p: optional extra input batch. When not ``None`` it is forwarded
                to ``self._learn`` as a third positional argument.

        Returns:
            Whatever ``self._learn`` returns. For :meth:`ce` this is the
            mean-reduced cross-entropy loss tensor.
        """
        if x_p is None:
            return self._learn(x, y)
        return self._learn(x, y, x_p)

    def ce(self, x, y, x_p=None):
        """Cross-entropy between ``self.model(x)`` logits and targets ``y``.

        Args:
            x: input batch passed to ``self.model``.
            y: targets in a form accepted by ``nn.CrossEntropyLoss``.
            x_p: accepted so that ``__call__`` can forward an extra batch
                without raising ``TypeError``. Plain cross-entropy does not
                use it.

        Returns:
            The mean-reduced cross-entropy loss tensor.
        """
        logits = self.model(x)
        return self.criterion(logits, y)

